{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 进阶RAG检索   Contextual Compression + Filtering  "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 前期准备"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
    "from langchain.document_loaders import PyPDFLoader\n",
    "from langchain.embeddings.openai import OpenAIEmbeddings\n",
    "from langchain.vectorstores import Chroma\n",
    "\n",
    "\n",
    "# Load pdf\n",
    "loader = PyPDFLoader(\"E:\\\\langchain_RAG\\\\data\\\\baichuan.pdf\")\n",
    "data = loader.load()\n",
    "\n",
    "# Split\n",
    "text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=0)\n",
    "splits = text_splitter.split_documents(data[:6])\n",
    "\n",
    "import os\n",
    "from getpass import getpass\n",
    "\n",
    "OPENAI_API_KEY = getpass()\n",
    "\n",
    "os.environ[\"OPENAI_API_KEY\"] = OPENAI_API_KEY\n",
    "\n",
    "# VectorDB\n",
    "embedding = OpenAIEmbeddings()\n",
    "vectordb = Chroma.from_documents(documents=splits, embedding=embedding)\n",
    "\n",
    "retriever = vectordb.as_retriever()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "base_docs = retriever.get_relevant_documents(\n",
    "    \"What is baichuan2 ？\"\n",
    ")\n",
    "\n",
    "base_docs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Contextual Extractor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.retrievers import ContextualCompressionRetriever\n",
    "from langchain.retrievers.document_compressors import LLMChainExtractor\n",
    "from langchain.chat_models import ChatOpenAI\n",
    "\n",
    "llm = ChatOpenAI(temperature=0)\n",
    "compressor = LLMChainExtractor.from_llm(llm)\n",
    "compression_retriever = ContextualCompressionRetriever(\n",
    "    base_compressor=compressor, base_retriever=retriever\n",
    ")\n",
    "\n",
    "compressed_docs = compression_retriever.get_relevant_documents(\n",
    "    \"What is baichuan2 ？\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "compressed_docs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Contextual filters"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### LLMChainFilter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.retrievers.document_compressors import LLMChainFilter\n",
    "\n",
    "_filter = LLMChainFilter.from_llm(llm)\n",
    "compression_retriever = ContextualCompressionRetriever(\n",
    "    base_compressor=_filter, base_retriever=retriever\n",
    ")\n",
    "\n",
    "compressed_docs = compression_retriever.get_relevant_documents(\n",
    "    \"What is baichuan2 ？\"\n",
    ")\n",
    "\n",
    "compressed_docs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### EmbeddingsFilter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.retrievers.document_compressors import EmbeddingsFilter\n",
    "\n",
    "embeddings_filter = EmbeddingsFilter(embeddings=embedding, similarity_threshold=0.76)\n",
    "compression_retriever = ContextualCompressionRetriever(\n",
    "    base_compressor=embeddings_filter, base_retriever=retriever\n",
    ")\n",
    "\n",
    "compressed_docs = compression_retriever.get_relevant_documents(\n",
    "    \"What is baichuan2 ？\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "compressed_docs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "langchain",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
